"""
Module contains tools to manipulate matplotlib graphs.
"""
import matplotlib as mpl
import numpy as np
import pandas as pd
import pandas_datareader.data as web
from matplotlib import rcParams
from matplotlib.ticker import FuncFormatter

from cat_rfs.code.utils.tools import find_frequency


def add_thousand_separators(ax, axis='y'):
    """Format graph tick labels to have thousands separated by commas.

    Tick labels are rendered with ``'{x:,.0f}'`` (comma separators, no decimals).

    Parameters
    ----------
    ax : matplotlib Axes, or a (possibly nested) numpy array / list / tuple of Axes
        For example the array returned by ``plt.subplots`` or the list ``fig.axes``.
    axis : {'y', 'x', 'both'}
        Which axis (or axes) to reformat.

    Returns
    -------
    The same ``ax`` object that was passed in.

    Raises
    ------
    AssertionError
        If ``axis`` is not one of the accepted values.
    """
    assert axis in {'y', 'x', 'both'}, f"axis must be 'y', 'x' or 'both', got {axis!r}"

    def _format(ax1):
        # Containers of axes (numpy arrays from plt.subplots, lists such as
        # fig.axes, tuples) are walked recursively down to individual Axes.
        if isinstance(ax1, (np.ndarray, list, tuple)):
            for elem in ax1:
                _format(elem)
            return
        # A new formatter instance is created for each axis instead of sharing one.
        if axis in {'y', 'both'}:
            ax1.yaxis.set_major_formatter(mpl.ticker.StrMethodFormatter('{x:,.0f}'))
        if axis in {'x', 'both'}:
            ax1.xaxis.set_major_formatter(mpl.ticker.StrMethodFormatter('{x:,.0f}'))

    _format(ax)
    return ax


def convert_to_percentages(ax, axis='y', digits=1):
    """Convert axis tick labels from decimals to percentages.

    A tick value of ``0.123`` with ``digits=1`` is rendered as ``'12.3%'``.

    Parameters
    ----------
    ax : matplotlib Axes, or a (possibly nested) numpy array / list / tuple of Axes
        For example the array returned by ``plt.subplots`` or the list ``fig.axes``.
    axis : {'y', 'x', 'both'}
        Which axis (or axes) to reformat.
    digits : int
        Number of decimal places shown after the percentage value.

    Returns
    -------
    The same ``ax`` object that was passed in.

    Raises
    ------
    AssertionError
        If ``axis`` is not one of the accepted values.
    """
    assert axis in {'y', 'x', 'both'}, f"axis must be 'y', 'x' or 'both', got {axis!r}"

    def _as_percent(value, _pos):
        # FuncFormatter calls this with (tick value, tick position).
        return '{:.{}%}'.format(value, digits)

    def _format(ax1):
        # Containers of axes (numpy arrays from plt.subplots, lists such as
        # fig.axes, tuples) are walked recursively down to individual Axes.
        if isinstance(ax1, (np.ndarray, list, tuple)):
            for elem in ax1:
                _format(elem)
            return
        # A new formatter instance is created for each axis instead of sharing one.
        if axis in {'y', 'both'}:
            ax1.yaxis.set_major_formatter(FuncFormatter(_as_percent))
        if axis in {'x', 'both'}:
            ax1.xaxis.set_major_formatter(FuncFormatter(_as_percent))

    _format(ax)
    return ax


def add_footnote(ax, text):
    """Write *text* as a footnote just below the plotting area of *ax*.

    The annotation is anchored in axes-fraction coordinates at x=0 and top-aligned.
    It sits 0.15 below the axes when ``rcParams['savefig.bbox']`` is ``'tight'``,
    and 0.20 below otherwise.

    Returns the same ``ax`` object.
    """
    uses_tight_bbox = rcParams.get('savefig.bbox') == 'tight'
    vertical_offset = -0.15 if uses_tight_bbox else -0.20
    ax.annotate(text, (0, 0), (0, vertical_offset), xycoords='axes fraction', va='top')
    return ax


def _nber_align_index(nber_df, xvals):
    """Return *nber_df* indexed to match the plotted *xvals*, or None if unsupported."""
    first = xvals[0]
    if isinstance(first, pd.Period):
        # Plotted periods: convert the indicator's dates to the same frequency.
        nber_df.index = nber_df.index.to_period(first.freqstr)
        return nber_df
    if isinstance(xvals, np.ndarray):
        # Plain arrays: infer a frequency string; anything else is unsupported.
        freqstr = find_frequency(pd.Series(xvals))
        if isinstance(freqstr, str):
            nber_df.index = nber_df.index.to_period(freqstr)
            return nber_df
        return None
    if isinstance(first, pd.Timestamp):
        print('TIMESTAMP INDEX FUNCTIONALITY NOT TESTED')
        # Forward-fill the indicator onto the plotted timestamps.
        return nber_df.reindex(xvals, method='pad')
    return None


def _nber_recession_spans(usrec):
    """Return ``[start, end]`` pairs for each recession in the *usrec* Series.

    A recession starts at a date whose value is 1 and ends at the first later
    date whose value is 0. If a recession is still open at the end of the data,
    it ends at the last date. The index is assumed to be sorted ascending.
    """
    dates = list(usrec.index)
    starts_recession = (usrec == 1).tolist()
    ends_recession = (usrec == 0).tolist()
    spans = []
    start = None
    for date, is_rec, is_expansion in zip(dates, starts_recession, ends_recession):
        if start is None:
            if is_rec:
                start = date
        elif is_expansion:
            spans.append([start, date])
            start = None
    if start is not None:
        spans.append([start, dates[-1]])
    return spans


def add_nber_recessions(ax, footnote=False, color='gray'):
    """Fetch NBER recession date indicators online and add shaded areas to input graph.

    The FRED series ``USREC`` is downloaded with ``pandas_datareader``, so this
    requires network access. The x data of the first line on ``ax`` determines
    how the indicator is aligned to the graph:

    * ``pd.Period`` values: indicator dates are converted to that period frequency.
    * numpy arrays: a frequency is inferred with ``find_frequency``. Conversion to
      periods only happens if a string frequency is returned.
    * ``pd.Timestamp`` values in a non-ndarray container: the indicator is
      forward-filled onto the plotted timestamps.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes with at least one plotted line.
    footnote : bool
        If exactly ``True``, add a "Shaded areas indicate U.S. recessions" footnote.
    color : str
        Fill color of the shaded recession spans, drawn with alpha 0.2.

    Returns
    -------
    ``ax`` on success. ``None`` (after printing a message) when the axis has no
    plotted data or its x values are of an unsupported type.
    """
    if not ax.lines:
        print('No plotted data on axis - cannot align NBER recessions')
        return None
    xvals = ax.lines[0].get_xdata()
    if len(xvals) == 0:
        print('No plotted data on axis - cannot align NBER recessions')
        return None

    start_date = pd.to_datetime('12-31-1900')
    end_date = pd.to_datetime('12-31-2050')
    nber_df = web.DataReader('USREC', 'fred', start=start_date, end=end_date)

    nber_df = _nber_align_index(nber_df, xvals)
    if nber_df is None:
        print('Unsupported axis type - Try Period or Timestamp')
        return None

    for start, end in _nber_recession_spans(nber_df['USREC']):
        ax.axvspan(start, end, alpha=0.2, color=color)

    if footnote is True:
        add_footnote(ax, 'Shaded areas indicate U.S. recessions')

    return ax
